# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
import warnings
from functools import partial

import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from torch import nn

from mmcls.models import build_classifier

torch.manual_seed(3)

try:
    import coremltools as ct
except ImportError:
    raise ImportError('Please install coremltools to enable output file.')


def _demo_mm_inputs(input_shape: tuple, num_classes: int):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    gt_labels = rng.randint(
        low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(False),
        'gt_labels': torch.LongTensor(gt_labels),
    }
    return mm_inputs


def pytorch2mlmodel(model: nn.Module, input_shape: tuple, output_file: str,
                    add_norm: bool, norm: dict):
    """Export Pytorch model to mlmodel format that can be deployed in apple
    devices through torch.jit.trace and the coremltools library.

       Optionally, embed the normalization step as a layer to the model.

    Args:
        model (nn.Module): Pytorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        show (bool): Whether print the computation graph. Default: False.
        output_file (string): The path to where we store the output
            TorchScript model.
        add_norm (bool): Whether to embed the normalization layer to the
            output model.
        norm (dict): image normalization config for embedding it as a layer
            to the output model.
    """
    model.cpu().eval()

    num_classes = model.head.num_classes
    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_list = [img[None, :] for img in imgs]
    model.forward = partial(model.forward, img_metas={}, return_loss=False)

    with torch.no_grad():
        trace_model = torch.jit.trace(model, img_list[0])
        save_dir, _ = osp.split(output_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        if add_norm:
            means, stds = norm.mean, norm.std
            if stds.count(stds[0]) != len(stds):
                warnings.warn(f'Image std from config is {stds}. However, '
                              'current version of coremltools (5.1) uses a '
                              'global std rather than the channel-specific '
                              'values that torchvision uses. A mean will be '
                              'taken but this might tamper with the resulting '
                              'model\'s predictions. For more details refer '
                              'to the coreml docs on ImageType pre-processing')
                scale = np.mean(stds)
            else:
                scale = stds[0]

            bias = [-mean / scale for mean in means]
            image_input = ct.ImageType(
                name='input_1',
                shape=input_shape,
                scale=1 / scale,
                bias=bias,
                color_layout='RGB',
                channel_first=True)

            coreml_model = ct.convert(trace_model, inputs=[image_input])
            coreml_model.save(output_file)
        else:
            coreml_model = ct.convert(
                trace_model, inputs=[ct.TensorType(shape=input_shape)])
            coreml_model.save(output_file)

        print(f'Successfully exported coreml model: {output_file}')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMCls to MlModel format for apple devices')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', type=str)
    parser.add_argument('--output-file', type=str, default='model.mlmodel')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    parser.add_argument(
        '--add-norm-layer',
        action='store_true',
        help='embed normalization layer to deployed model')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load checkpoint
    classifier = build_classifier(cfg.model)

    if args.checkpoint:
        load_checkpoint(classifier, args.checkpoint, map_location='cpu')

    # convert model to mlmodel file
    pytorch2mlmodel(
        classifier,
        input_shape,
        output_file=args.output_file,
        add_norm=args.add_norm_layer,
        norm=cfg.img_norm_cfg)
